{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import argparse\n",
    "import subprocess\n",
    "from com.yahoo.ml.tf import TFCluster\n",
    "import mnist_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument(\"-e\", \"--epochs\", help=\"number of epochs\", type=int, default=0)\n",
    "parser.add_argument(\"-i\", \"--images\", help=\"HDFS path to MNIST images in parallelized format\")\n",
    "parser.add_argument(\"-f\", \"--format\", help=\"example format\", choices=[\"csv\",\"tfr\"], default=\"tfr\")\n",
    "parser.add_argument(\"-m\", \"--model\", help=\"HDFS path to save/load model\", default=\"mnist_model\")\n",
    "parser.add_argument(\"-o\", \"--output\", help=\"HDFS path to save test/inference output\", default=\"predictions\")\n",
    "parser.add_argument(\"-r\", \"--readers\", help=\"number of reader/enqueue threads\", type=int, default=1)\n",
    "parser.add_argument(\"-s\", \"--steps\", help=\"maximum number of steps\", type=int, default=1000)\n",
    "parser.add_argument(\"-X\", \"--mode\", help=\"train|inference\", default=\"train\")\n",
    "parser.add_argument(\"-c\", \"--rdma\", help=\"use rdma connection\", default=False)\n",
    "parser.add_argument(\"-tb\", \"--tensorboard\", help=\"launch tensorboard process\", action=\"store_true\")\n",
    "#Number of executors you have actually launched\n",
    "num_executors = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#remove existing models and predictions\n",
    "subprocess.call([\"hadoop\", \"fs\", \"-rm\", \"-R\", \"mnist_model\"])\n",
    "subprocess.call([\"hadoop\", \"fs\", \"-rm\", \"-R\", \"predictions\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2017-02-09 19:38:54,066 INFO (MainThread-7122) Reserving TFSparkNodes w/ TensorBoard\n",
      "2017-02-09 19:39:05,708 INFO (MainThread-7122) TensorBoard running at: http://ip-172-31-25-197:59675\n",
      "2017-02-09 19:39:05,710 INFO (MainThread-7122) Starting TensorFlow\n"
     ]
    }
   ],
   "source": [
    "#reserve a TensorFlow cluster\n",
    "cluster = TFCluster.reserve(sc, num_executors, 1, True, TFCluster.InputMode.TENSORFLOW)\n",
    "#kick off training\n",
    "args = parser.parse_args(['--mode', 'train', \n",
    "                          '--images', 'mnist/tfr/train'])\n",
    "cluster.start(mnist_dist.map_fun, args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2017-02-09 19:39:28,498 INFO (MainThread-7122) Stopping TensorFlow nodes\n"
     ]
    }
   ],
   "source": [
    "#The cluster will only be shutddown when the training is actually completed.\n",
    "#It will take a few minutes.\n",
    "cluster.shutdown()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 17 items\n",
      "-rw-r--r--   3 root supergroup        265 2017-02-09 19:40 mnist_model/checkpoint\n",
      "-rw-r--r--   3 root supergroup     142752 2017-02-09 19:39 mnist_model/graph.pbtxt\n",
      "-rw-r--r--   3 root supergroup     814164 2017-02-09 19:40 mnist_model/model.ckpt-416.data-00000-of-00001\n",
      "-rw-r--r--   3 root supergroup        372 2017-02-09 19:40 mnist_model/model.ckpt-416.index\n",
      "-rw-r--r--   3 root supergroup      56894 2017-02-09 19:40 mnist_model/model.ckpt-416.meta\n",
      "-rw-r--r--   3 root supergroup     814164 2017-02-09 19:40 mnist_model/model.ckpt-543.data-00000-of-00001\n",
      "-rw-r--r--   3 root supergroup        372 2017-02-09 19:40 mnist_model/model.ckpt-543.index\n",
      "-rw-r--r--   3 root supergroup      56894 2017-02-09 19:40 mnist_model/model.ckpt-543.meta\n",
      "-rw-r--r--   3 root supergroup     814164 2017-02-09 19:40 mnist_model/model.ckpt-669.data-00000-of-00001\n",
      "-rw-r--r--   3 root supergroup        372 2017-02-09 19:40 mnist_model/model.ckpt-669.index\n",
      "-rw-r--r--   3 root supergroup      56894 2017-02-09 19:40 mnist_model/model.ckpt-669.meta\n",
      "-rw-r--r--   3 root supergroup     814164 2017-02-09 19:40 mnist_model/model.ckpt-795.data-00000-of-00001\n",
      "-rw-r--r--   3 root supergroup        372 2017-02-09 19:40 mnist_model/model.ckpt-795.index\n",
      "-rw-r--r--   3 root supergroup      56894 2017-02-09 19:40 mnist_model/model.ckpt-795.meta\n",
      "-rw-r--r--   3 root supergroup     814164 2017-02-09 19:40 mnist_model/model.ckpt-920.data-00000-of-00001\n",
      "-rw-r--r--   3 root supergroup        372 2017-02-09 19:40 mnist_model/model.ckpt-920.index\n",
      "-rw-r--r--   3 root supergroup      56894 2017-02-09 19:40 mnist_model/model.ckpt-920.meta\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#examine the newly trained model\n",
    "print(subprocess.check_output([\"hadoop\", \"fs\", \"-ls\", \"mnist_model\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2017-02-09 19:41:32,734 INFO (MainThread-7122) Reserving TFSparkNodes\n",
      "2017-02-09 19:41:43,519 INFO (MainThread-7122) Starting TensorFlow\n"
     ]
    }
   ],
   "source": [
    "#reserve a TensorFlow cluster\n",
    "cluster = TFCluster.reserve(sc, num_executors, 1, False, TFCluster.InputMode.TENSORFLOW)\n",
    "#kick off inference using the trained model\n",
    "inf_args = parser.parse_args(['--mode', 'inference', \n",
    "                              '--images', 'mnist/tfr/test', \n",
    "                              '--epochs', '1'])\n",
    "cluster.start(mnist_dist.map_fun, inf_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2017-02-09 19:41:48,531 INFO (MainThread-7122) Stopping TensorFlow nodes\n"
     ]
    }
   ],
   "source": [
    "#The cluster will only be shutddown when the inference is actually completed\n",
    "cluster.shutdown()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[u'7 7',\n",
       " u'2 2',\n",
       " u'1 1',\n",
       " u'0 0',\n",
       " u'4 4',\n",
       " u'1 1',\n",
       " u'4 4',\n",
       " u'9 9',\n",
       " u'5 6',\n",
       " u'9 9']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#The prediction result with (lable, prediction) for each example\n",
    "predictions = sc.textFile(\"predictions\")\n",
    "predictions.take(10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
